// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <deque>

#include "perf_model/graph.hpp"

namespace tt::perf_model
{

// Forward declarations of types that are defined elsewhere; this header only needs their names.
class Simulator;
class Buffer;
class SimCache;
struct SimState;
// Uniquely-owning pointer aliases for the simulation cache and simulation state
using SimCacheP = std::unique_ptr<SimCache>;
using SimStateP = std::unique_ptr<SimState>;

// Amount / timestamp pair.
// Both members default to zero so that a default-constructed TimeData never carries indeterminate values.
// The struct remains an aggregate, so `TimeData{count, timestamp}` keeps working.
struct TimeData
{
    std::uint32_t count = 0;      // amount
    std::uint32_t timestamp = 0;  // time associated with that amount
};

// Return structure from event process() calls. Bundles the three outcomes of a processing step: what the event
// is stalled on, which buffers changed, and which new events were created.
class DataEvent;
struct ProcessStatus
{
    std::vector<Buffer *> stall_reason;      // Buffers we're stalled on - run this event again when this buffer changes
    std::vector<Buffer *> modified_buffers;  // Buffers modified by the processing step
    std::vector<DataEvent *> new_events;     // New events generated by the processing step
};

// Abstract base class for data events. Every event records the microbatch input index it belongs to, a
// count/timestamp pair, and the buffer it is associated with. Concrete events implement process() and to_string().
class DataEvent
{
   public:
    DataEvent(std::uint32_t input_index, TimeData data, Buffer *buffer) :
        input_index(input_index), data(data), buffer(buffer), unprocessed(true)
    {
    }

    virtual ~DataEvent() = default;

    // Read-only accessors for the creation-time fields and the processed flag
    std::uint32_t get_input_index() const { return input_index; }
    std::uint32_t timestamp() const { return data.timestamp; }
    Buffer *get_buffer() const { return buffer; }
    bool is_unprocessed() const { return unprocessed; }

    // Process this event. The returned ProcessStatus lists the buffers we're stalled on (if stalled), the buffers
    // modified by this step, and any new events that were generated.
    virtual ProcessStatus process(SimStateP &sim_state, SimCacheP &cache, const std::string &arch_name) = 0;

    // String representation of the event; the format is defined by each concrete event type.
    virtual std::string to_string() const = 0;

    // Events compare by timestamp only.
    bool operator<(const DataEvent &other) const { return timestamp() < other.timestamp(); }

   protected:
    //
    // Fixed when the event is created
    //
    std::uint32_t input_index;  // input index in a microbatch
    TimeData data;              // count and time at which this happened
    Buffer *buffer;             // buffer in which data was produced or pushed to

    //
    // Updated while simulating
    //
    bool unprocessed;  // starts out true; set if this event has never been processed
};

// Data received by the output buffer. If receiving buffers have room, they will receive data, creating an
// InputDataEvent
class OutputDataEvent : public DataEvent
{
    //
    // Set at creation
    //
    std::vector<std::pair<Buffer *, std::uint32_t>> consumers;  // consuming node + operand pair

    //
    // Set during simulation
    //
    std::vector<std::vector<TimeData>> consumed;  // timestamp/amount pair of consumed data, per consumer
    std::vector<std::uint32_t> remaining;         // remaining data to send, per consumer

   public:
    OutputDataEvent(
        std::uint32_t input_index,
        TimeData data,
        Buffer *output_buffer,
        const std::vector<std::pair<Buffer *, std::uint32_t>> &consumers);

    // Process this event. Return pointer to buffer on which we're stalled, if stalled...
    // If any new events have been generated, populate them in new_events vector.
    virtual ProcessStatus process(SimStateP &sim_state, SimCacheP &cache, std::string const& arch_name) override;

    virtual std::string to_string() const override;
};

// Data received by input buffer. If attached node is ready, it'll go on to produce data in the output buffer.
class InputDataEvent : public DataEvent
{
    //
    // Set at creation
    //

    //
    // Set during simulation
    //
    std::vector<TimeData> consumed;  // amount/timestamp consumed by the op
    // std::uint32_t remaining_amount;  // unconsumed amount
    // std::uint32_t completed_timestamp;

   public:
    InputDataEvent(std::uint32_t input_index, TimeData data, Buffer *receiver) : DataEvent(input_index, data, receiver)
    {
    }

    // Process this event. Return pointer to buffer on which we're stalled, if stalled...
    // If any new events have been generated, populate them in new_events vector.
    virtual ProcessStatus process(SimStateP &sim_state, SimCacheP &cache, std::string const& arch_name) override;

    virtual std::string to_string() const override;
};

// Operation execution
class OpDataEvent : public DataEvent
{
    //
    // Set at creation
    //
    NodeP op;
    std::uint32_t current_t;  // Op only executes one T, and schedules the next one
    std::uint32_t total_t;

    std::uint32_t current_ublock; // For ops that produce ublock by ublock (all by matmul)
    std::uint32_t total_ublocks;

    // Specific to matmuls, multiple inner loops needed before output is created
    std::uint32_t total_k = 1;
    std::uint32_t current_k;

    //
    // Simulation
    //

   public:
    OpDataEvent(
        std::uint32_t input_index,
        TimeData data,
        Buffer *output_buffer,
        std::uint32_t current_t,
        std::uint32_t current_ublock,
        std::uint32_t current_k);

    // Process this event. Return pointer to buffer on which we're stalled, if stalled...
    // If any new events have been generated, populate them in new_events vector.
    virtual ProcessStatus process(SimStateP &sim_state, SimCacheP &cache, std::string const& arch_name) override;

    virtual std::string to_string() const override;
};

// Mid-graph queue
class QueueDataEvent : public DataEvent
{
   public:
    QueueDataEvent(
        std::uint32_t input_index,
        TimeData data,
        Buffer *input_buffer) : DataEvent(input_index, data, input_buffer) {}

    // Process this event. Return pointer to buffer on which we're stalled, if stalled...
    // If any new events have been generated, populate them in new_events vector.
    virtual ProcessStatus process(SimStateP &sim_state, SimCacheP &cache, std::string const& arch_name) override;

    virtual std::string to_string() const override;

};

// Host write to input
class HostWriteDataEvent : public DataEvent
{
   public:
    HostWriteDataEvent(std::uint32_t input_index, TimeData data, Buffer *device_buffer) :
        DataEvent(input_index, data, device_buffer)
    {
    }

    // Process this event. Return pointer to buffer on which we're stalled, if stalled...
    // If any new events have been generated, populate them in new_events vector.
    virtual ProcessStatus process(SimStateP &sim_state, SimCacheP &cache, std::string const& arch_name) override;

    virtual std::string to_string() const override;
};

// Host read from output
class HostReadDataEvent : public DataEvent
{
   public:
    HostReadDataEvent(std::uint32_t input_index, TimeData data, Buffer *device_buffer) :
        DataEvent(input_index, data, device_buffer)
    {
    }

    // Process this event. Return pointer to buffer on which we're stalled, if stalled...
    // If any new events have been generated, populate them in new_events vector.
    virtual ProcessStatus process(SimStateP &sim_state, SimCacheP &cache, std::string const& arch_name) override;

    virtual std::string to_string() const override;
};

}  // namespace tt::perf_model
